import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
import torch
from network.predict_net import Predict_mse

def setup_seed(seed):
    """Seed PyTorch's random generators and configure cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to the CPU generator and to every CUDA
            device generator.

    Side effects:
        * Forces cuDNN to use deterministic algorithms.
        * Disables cuDNN benchmark mode.
        * Limits intra-op CPU parallelism to 8 threads.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm and may select a
    # different one from run to run, which defeats the purpose of seeding.
    torch.backends.cudnn.benchmark = False
    torch.set_num_threads(8)
class RNN_quick(nn.Module):
    """GRU-based Q-network without inter-agent communication.

    Pipeline: ``Linear -> ReLU -> GRU (batch_first) -> Linear``. Unlike a
    ``GRUCell`` agent, a whole sequence is processed in a single call.
    """

    def __init__(self, input_shape, args):
        """Build the layers.

        Args:
            input_shape: Size of the last (feature) axis of the observations.
            args: Config object providing ``rnn_hidden_dim`` and ``n_actions``.
        """
        super().__init__()
        self.args = args
        self.input_shape = input_shape
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    def init_hidden(self):
        """Return a zero tensor of shape ``(1, rnn_hidden_dim)``.

        The tensor is created from ``fc1.weight`` so it shares the model's
        device and dtype.
        """
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def forward(self, obs, hidden_state):
        """Compute Q-values for a batch of observation sequences.

        Args:
            obs: Either a 3-D tensor ``(batch, seq, feature)`` or a 2-D tensor
                ``(seq, feature)``. A 2-D tensor is treated as a single batch
                element and the batch axis is stripped from the outputs again.
                The tensor does not need to be contiguous.
            hidden_state: Initial GRU state, ``(batch, hidden)`` or
                ``(1, batch, hidden)``.

        Returns:
            Tuple ``(q, gru_out)``: ``q`` holds the per-step action values and
            ``gru_out`` the per-step GRU outputs (not only the final state).
            Both keep the leading axes of ``obs``.
        """
        # nn.GRU expects the state as (num_layers, batch, hidden).
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.unsqueeze(0)
        single_batch = obs.dim() == 2
        if single_batch:
            obs = obs.unsqueeze(0)

        n_batch, n_steps = obs.shape[0], obs.shape[1]

        # reshape (not view) so non-contiguous inputs, e.g. a permuted batch,
        # are accepted instead of raising a RuntimeError.
        flat_obs = obs.reshape(-1, obs.shape[-1])
        features = f.relu(self.fc1(flat_obs)).reshape(n_batch, n_steps, -1)

        gru_out, _ = self.rnn(features, hidden_state)

        q = self.fc2(gru_out.reshape(-1, gru_out.shape[-1]))
        q = q.reshape(n_batch, n_steps, -1)

        if single_batch:
            return q[0], gru_out[0]
        return q, gru_out

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op; present so this class can be called like the variants that train an auxiliary model."""
        pass
class RNN_preAC(nn.Module):
    """GRU Q-network whose input is the observation concatenated with an
    action encoding of width ``n_agents * n_actions``.

    In ``forward`` that encoding is built from the output of the
    ``Predict_mse`` sub-model (``self.Pre_ac``); the ``get_*`` helpers accept
    or sample the encoding explicitly instead.
    """
    # no communication
    def __init__(self, input_shape, args):
        """Build the layers.

        Reads ``n_actions``, ``n_agents``, ``rnn_hidden_dim`` and ``seed``
        from ``args``; other methods additionally read ``beta``, ``beta1``
        and ``beta2``.
        """
        super(RNN_preAC, self).__init__()
        self.args = args
        self.n_actions=args.n_actions
        self.n_agents=args.n_agents
        self.input_shape = input_shape
        # number of random action draws per row used by get_mean_con
        self.sample=16
        # input layer sees the observation plus the flattened action encoding
        self.fc1 = nn.Linear(input_shape+self.n_agents*self.n_actions, args.rnn_hidden_dim)
        # (1, n_actions) grid of evenly spaced values in [0.1, 1]; forward()
        # snaps Pre_ac outputs to the nearest grid entry. Plain tensor (not a
        # registered buffer), hence the explicit .to(device) in forward().
        self.action_list = torch.linspace(.1, 1, self.n_actions).reshape(1, -1)
        self.rnn = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        # project-local predictor; argument meanings are defined in
        # network.predict_net — presumably (input dim, hidden dim, output dim, flag); TODO confirm
        self.Pre_ac = Predict_mse(
            input_shape, 128, self.n_agents, False)
        output_dim=32
        # small MLP embedding the action encoding before the output layer
        self.fca=nn.Sequential(
            nn.Linear(self.n_agents*self.n_actions, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

        # output layer consumes GRU output concatenated with the action embedding
        self.fc2 = nn.Linear(args.rnn_hidden_dim+output_dim, args.n_actions)
        # NOTE(review): seeding happens after the layers above were created, so
        # it does not influence their initial weights — confirm this is intended.
        setup_seed(args.seed)
    def init_hidden(self):
        """Return a zero tensor of shape (1, rnn_hidden_dim) on the same device/dtype as fc1's weight."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def update(self, own_variable, other_variable, mask):
        """Delegate to ``self.Pre_ac.update`` and return whatever it returns."""
        return self.Pre_ac.update(own_variable, other_variable, mask)
    def get_intrinsic(self,obs,u_real,eval_h,q,out_eval_h):
        """Compute two KL-divergence terms between softmax policies.

        ``q`` / ``out_eval_h`` are expected to be the outputs of a previous
        ``forward(obs, eval_h)`` call (see the commented-out line below).

        Returns:
            ``(pi_random_diverge, pi_diverge)``: KL(softmax(q) || mean softmax
            of ``q_random``) and KL(softmax(q) || softmax(q_new)), each
            reshaped to ``(-1, n_agents, T, 1)`` and then permuted so the
            ``n_agents`` axis comes after the ``T`` axis.
            Assumes the leading axis of ``q`` is ordered as
            (batch, agent) flattened — TODO confirm against the caller.
        """
        #q,out_eval_h=self.forward(obs,eval_h)
        # Shift the per-step GRU outputs one step along dim 1 and put the
        # (detached) initial state in slot 0, so slot t holds the state that
        # entered step t.
        out_eval_h=out_eval_h.roll(shifts=1,dims=1)
        out_eval_h[:,0,:]=eval_h.clone().detach()
        # beta / beta1 / beta2 come from args; a zero value skips recomputing
        # the corresponding Q estimate and falls back to q itself (presumably
        # these are the weights of the two intrinsic terms — verify).
        if self.args.beta==0:
            q_random=q
            q_new=q
        elif self.args.beta1 == 0:
            q_random=q
            q_new = self.get_q_con(obs, u_real, eval_h)

        elif self.args.beta2 == 0 :
            q_random=self.get_mean_con(obs, out_eval_h)
            q_new=q
        else:
            q_random=self.get_mean_con(obs, out_eval_h)
            q_new = self.get_q_con(obs, u_real, eval_h)
        q_pi = torch.softmax(1* q, dim=-1)
        q_new_pi = torch.softmax(1 * q_new, dim=-1)
        # get_mean_con returns a leading axis of size self.sample, which this
        # mean averages out.
        # NOTE(review): when q_random is q itself (first two branches) the
        # mean runs over q's own first axis instead — confirm this is intended.
        q_random_pi = (torch.softmax(1 * q_random, dim=-1)).mean(dim=0)
        # KL(q_pi || other) summed over the action axis
        pi_diverge=(q_pi*torch.log(q_pi/q_new_pi)).sum(dim=-1, keepdim=True)
        pi_random_diverge=(q_pi*torch.log(q_pi/q_random_pi)).sum(dim=-1, keepdim=True)
        pi_diverge=pi_diverge.reshape(-1,self.n_agents,pi_diverge.shape[1],1).permute(0,2,1,3)
        pi_random_diverge=pi_random_diverge.reshape(-1,self.n_agents,pi_random_diverge.shape[1],1).permute(0,2,1,3)

        return pi_random_diverge,pi_diverge

    def get_mean_con(self,obs,eval_h):
        """Evaluate Q under ``self.sample`` randomly drawn action encodings.

        ``eval_h`` must supply one hidden vector per (row, step) of ``obs``,
        because ``get_q_h`` treats every step as an independent length-1
        sequence.

        Returns:
            Tensor of shape ``(self.sample, obs.shape[0], obs.shape[1], n_out)``
            where ``n_out`` is the last axis produced by ``get_q_h``.
        """
        # one uniform score per (row-group, agent pair, action); the argmax
        # over the last axis picks a random action for each agent pair
        u_random = torch.rand(
            (int(self.sample*obs.shape[0]*obs.shape[1]/self.n_agents), self.n_agents * self.n_agents, self.n_actions)).to(obs.device)
        # (n_agents^2, n_actions) mask that is 0 on the i == j pairs, i.e. it
        # blanks each agent's own slot in its encoding
        eye=(1-torch.eye(self.n_agents)).reshape(-1,1).repeat(1,self.n_actions).to(obs.device)
        # one-hot of the argmax, masked, then flattened to rows of width
        # n_agents * n_actions
        u_index=((u_random.max(axis=-1,keepdim=True)[0]==u_random).long()*eye).reshape(-1,self.n_agents*u_random.shape[-1])
        dim_0=obs.shape[0]
        # tile inputs so every sample gets its own copy
        obs=obs.repeat(self.sample,1,1)
        eval_h=eval_h.repeat(self.sample,1,1)
        q_mean,_=self.get_q_h(obs,u_index, eval_h)
        # split the tiled leading axis back into (sample, original rows)
        q_mean=q_mean.reshape(self.sample,dim_0,q_mean.shape[1],q_mean.shape[2])
        return q_mean
    def get_q_con(self,obs,u_real,eval_h):
        """Evaluate Q conditioned on the supplied actions ``u_real``.

        Assumes ``u_real`` is a 4-D one-hot tensor laid out as
        (batch, step, agent, action) — TODO confirm; the code repeats axis 2
        ``n_agents`` times and masks each agent's own slot before flattening.
        """
        # same own-slot mask as in get_mean_con
        eye=(1-torch.eye(self.n_agents)).reshape(-1,1).repeat(1,self.n_actions).to(obs.device)
        u_real=(u_real.repeat(1,1,self.n_agents,1)*eye).reshape(u_real.shape[0],u_real.shape[1],self.n_agents,-1)
        # swap axes 1 and 2, then flatten to one encoding row per flattened obs row
        u_real=u_real.permute(0,2,1,3).reshape(-1,u_real.shape[-1])
        q_new,_=self.get_q(obs,u_real, eval_h)
        return q_new   #eval_h=eval_h.reshape(-1,eval_h.shape[-1]).unsqueeze(0)
    def get_q(self,obs,ac_onehot_last, hidden_state):
        """Run the network over full sequences.

        Args:
            obs: 3-D tensor; its first two axes are used as the GRU's
                (batch, seq) axes.
            ac_onehot_last: 2-D action encoding with one row per flattened
                row of ``obs``.
            hidden_state: Initial GRU state, 2-D or with the leading layer axis.

        Returns:
            ``(q_, gru_out)``: Q-values reshaped to the leading axes of
            ``obs`` and the per-step GRU outputs.
        """
        if len(hidden_state.shape) == 2:
            hidden_state = hidden_state.unsqueeze(0)  # add the leading layer axis nn.GRU expects (original note: ensure size (1,3,64))
        obs_c = obs.reshape(-1, obs.shape[-1])
        new_obs_c=torch.cat((obs_c,ac_onehot_last),dim=-1)
        x_0 = f.relu(self.fc1(new_obs_c))
        #x = torch.cat((x_0, ac_onehot_last),dim=-1).unsqueeze(-1)
        x = x_0.reshape(obs.shape[0], obs.shape[1], -1)
        h_in = hidden_state
        gru_out, _ = self.rnn(x, h_in)
        gru_out_c = gru_out.reshape(-1, gru_out.shape[-1])
        # the action encoding is injected a second time, after the GRU
        ac_fc=self.fca(ac_onehot_last)
        gru_out_cat=torch.cat((gru_out_c,ac_fc),dim=-1)
        q = self.fc2(gru_out_cat)
        q_ = q.view(obs.shape[0], obs.shape[1], -1)  # 415

        return q_, gru_out
    def get_q_h(self,obs,ac_onehot_last,eval_h ):
        """Like ``get_q`` but every (row, step) is an independent one-step sequence.

        ``eval_h`` is flattened to one hidden vector per flattened row of
        ``obs`` and used as that row's initial GRU state.
        """
        hidden_state = eval_h.reshape(-1, eval_h.shape[-1]).unsqueeze(0)
        obs_c = obs.reshape(-1, obs.shape[-1])
        new_obs_c=torch.cat((obs_c,ac_onehot_last),dim=-1)
        x_0 = f.relu(self.fc1(new_obs_c))
        # sequence length 1 for every row
        x = x_0.unsqueeze(1)
        h_in = hidden_state
        gru_out, _ = self.rnn(x, h_in)
        ac_fc=self.fca(ac_onehot_last)

        gru_out_cat=torch.cat((gru_out,ac_fc.unsqueeze(1)),dim=-1)
        q = self.fc2(gru_out_cat)
        q_ = q.reshape(obs.shape[0], obs.shape[1], -1)  # 415

        return q_, gru_out

    def forward(self, obs, hidden_state):
        """Compute Q-values using action encodings derived from ``self.Pre_ac``.

        A 2-D ``obs`` gets a leading axis of size 1 that is stripped from both
        outputs again. The encoding is built under ``torch.no_grad()``, so no
        gradient reaches ``self.Pre_ac`` from this call.

        Returns:
            ``(q, gru_out)`` as produced by ``get_q``.
        """
        excute_label = False
        if len(hidden_state.shape) == 2:
            hidden_state = hidden_state.unsqueeze(0)  # add the leading layer axis nn.GRU expects (original note: ensure size (1,3,64))
        if len(obs.shape) == 2:
            obs = obs.unsqueeze(0)
            excute_label = True
        obs_c = obs.reshape(-1, obs.shape[-1])
        # last n_agents features, cast to long — presumably a one-hot agent id; TODO confirm
        agent_id=obs_c[...,-self.n_agents:].long()
        with torch.no_grad():
            # clamp predictions into the [first, last] range of the action grid
            Pre_ac=self.Pre_ac(obs_c).clip(min=self.action_list[0][0].item(),max=self.action_list[0,-1].item())
            ac_onehot=torch.zeros(torch.Size(list(Pre_ac.shape)+[self.n_actions])).to(obs_c.device)
            # mask that is 0 where agent_id is 1, i.e. blanks the row's own agent slot
            id_index=(torch.ones_like(ac_onehot)*(1-agent_id.unsqueeze(-1)))
            # index of the grid entry closest (squared distance) to each prediction
            ac_index=((Pre_ac.reshape(-1,1)-self.action_list.to(Pre_ac.device))**2).min(dim=-1)[1]
            ac_index=ac_index.reshape(Pre_ac.shape).unsqueeze(-1)
            ac_onehot.scatter_(-1,ac_index,1)
            # apply the mask and flatten to one encoding row per observation row
            ac_onehot_last=(ac_onehot*id_index).reshape(obs_c.shape[0],-1)
        q, gru_out=self.get_q(obs,ac_onehot_last, hidden_state)
        if excute_label:
            return q[0],gru_out[0]
        return q, gru_out

class RNN(nn.Module):
    """Single-step recurrent Q-network: ``Linear -> ReLU -> GRUCell -> Linear``.

    All agents share this network, so ``input_shape`` is
    obs_shape + n_actions + n_agents.
    """

    def __init__(self, input_shape, args):
        """Create the layers; ``args`` supplies ``rnn_hidden_dim`` and ``n_actions``."""
        super().__init__()
        self.args = args

        hidden_dim = args.rnn_hidden_dim
        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, args.n_actions)

    def forward(self, obs, hidden_state):
        """Advance one step.

        Args:
            obs: Input rows, one per agent/sample.
            hidden_state: Previous hidden state; flattened to
                ``(-1, rnn_hidden_dim)`` before use.

        Returns:
            ``(q, h)``: action values and the new hidden state.
        """
        prev_h = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        features = f.relu(self.fc1(obs))
        next_h = self.rnn(features, prev_h)
        return self.fc2(next_h), next_h

class RNN_utility(nn.Module):
    """GRU Q-network (no communication) with a latent-conditioned output layer.

    An encoder maps each observation to a Gaussian ``(mu, sigma)``. A sample
    ``z`` from it is turned into the weights and bias of the final linear
    layer, which is applied to the GRU output for that same observation.
    """

    def __init__(self, input_shape, args, h_dim=32, z_dim=16):
        """Build the layers and the encoder's optimizer.

        Args:
            input_shape: Size of the last (feature) axis of the observations.
            args: Config providing ``rnn_hidden_dim``, ``n_actions``,
                ``var_rate`` (scale applied to sigma) and ``wloss_rate``
                (weight of the encoder loss in ``update``).
            h_dim: Width of the encoder / latent hidden layers.
            z_dim: Dimension of the latent variable.
        """
        super().__init__()

        self.args = args
        self.h_dim = h_dim
        self.n_actions = args.n_actions
        self.input_shape = input_shape
        self.eps_rate = args.var_rate
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        # Not used by forward() (the generated layer replaces it); kept so the
        # parameter set and checkpoints stay unchanged.
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        # ---- encoder ----
        self.en_fc1 = nn.Linear(input_shape, h_dim)
        self.en_fc2 = nn.Linear(h_dim, z_dim)  # mu
        self.en_fc3 = nn.Linear(h_dim, z_dim)  # log_var

        # ---- latent -> generated output layer ----
        self.latent_net = nn.Linear(z_dim, h_dim)
        self.fc2_w_nn = nn.Linear(h_dim, args.rnn_hidden_dim * args.n_actions)
        self.fc2_b_nn = nn.Linear(h_dim, args.n_actions)
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
        self.loss_weight = args.wloss_rate

    def init_hidden(self):
        """Return a zero tensor of shape ``(1, rnn_hidden_dim)`` on the model's device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def reparameterization(self, mu, sigma):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        Args:
            mu: Mean of the Gaussian.
            sigma: Standard deviation (already scaled), same shape as ``mu``.

        Returns:
            The sampled ``z`` (element-wise product, not a matrix product).
        """
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def encode(self, x):
        """Encode inputs into Gaussian parameters.

        Args:
            x: Input rows with ``input_shape`` features.

        Returns:
            ``(mu, sigma)`` where ``sigma = var_rate * exp(0.5 * log_var)``.
        """
        h = f.relu(self.en_fc1(x))
        mu = self.en_fc2(h)
        log_var = self.en_fc3(h)
        sigma = self.eps_rate * torch.exp(log_var * 0.5)
        return mu, sigma

    def forward(self, obs, hidden_state):
        """Compute Q-values with a per-observation generated output layer.

        Args:
            obs: ``(batch, seq, feature)`` or ``(seq, feature)``; a 2-D input
                is treated as one batch element and the batch axis is removed
                from the outputs again. Need not be contiguous.
            hidden_state: Initial GRU state, ``(batch, hidden)`` or
                ``(1, batch, hidden)``.

        Returns:
            ``(q, gru_out)``: per-step action values and per-step GRU outputs.
        """
        # nn.GRU expects the state as (num_layers, batch, hidden).
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.unsqueeze(0)
        single_batch = obs.dim() == 2
        if single_batch:
            obs = obs.unsqueeze(0)

        n_batch, n_steps = obs.shape[0], obs.shape[1]

        # reshape (not view) so non-contiguous inputs do not raise.
        flat_obs = obs.reshape(-1, obs.shape[-1])

        x = f.relu(self.fc1(flat_obs)).reshape(n_batch, n_steps, -1)
        gru_out, _ = self.rnn(x, hidden_state)

        # The encoder is trained only through update(); block gradients here.
        with torch.no_grad():
            mu, sigma = self.encode(flat_obs)
            sampled_z = self.reparameterization(mu, sigma)

        latent = self.latent_net(sampled_z)
        fc2_w = self.fc2_w_nn(latent).reshape(
            -1, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = self.fc2_b_nn(latent).reshape(-1, 1, self.args.n_actions)

        # One (1 x hidden) @ (hidden x actions) product per observation row.
        h = gru_out.reshape(-1, 1, gru_out.shape[-1])
        q = torch.bmm(h, fc2_w) + fc2_b
        q = q.reshape(n_batch, n_steps, -1)

        if single_batch:
            return q[0], gru_out[0]
        return q, gru_out

    def EMD_Distance(self, mu1, sigma_1, mu2, sigma_2):
        """Return ``sum((mu2 - mu1)^2) + sum((sqrt(sigma_1) - sqrt(sigma_2))^2)`` over the last axis."""
        dst_1 = ((mu2 - mu1) ** 2).sum(dim=-1)
        dst_2 = ((sigma_1 ** 0.5 - sigma_2 ** 0.5) ** 2).sum(dim=-1)
        return dst_1 + dst_2

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """Take one optimizer step pulling the encoder output towards ``(vae_mu, vae_sigma)``.

        The per-row distance is weighted by ``wloss_rate`` and averaged over
        the rows selected by ``mask``. Nothing happens when ``mask`` sums to
        zero. Gradients are clipped to norm 1 before the step.
        """
        if mask.sum() > 0:
            mu, sigma = self.encode(inputs)
            loss = self.loss_weight * self.EMD_Distance(mu, sigma, vae_mu, vae_sigma)
            loss = (loss * mask).sum() / mask.sum()
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.)
            self.optimizer.step()

# Critic of Central-V
class Critic(nn.Module):
    """Three-layer MLP producing one scalar value per input row."""

    def __init__(self, input_shape, args):
        """Create the layers; ``args.critic_dim`` sets the hidden width."""
        super().__init__()
        self.args = args
        width = args.critic_dim
        self.fc1 = nn.Linear(input_shape, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, inputs):
        """Return the value estimate, last axis of size 1."""
        hidden = inputs
        # two ReLU-activated hidden layers followed by a linear head
        for layer in (self.fc1, self.fc2):
            hidden = f.relu(layer(hidden))
        return self.fc3(hidden)
class MLP(nn.Module):
    """Single linear head mapping a hidden state to per-action values."""

    def __init__(self, args):
        """Create the layer from ``args.rnn_hidden_dim`` to ``args.n_actions``."""
        super().__init__()
        self.args = args
        in_features, out_features = args.rnn_hidden_dim, args.n_actions
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, hidden_state):
        """Apply the linear layer to ``hidden_state`` and return the result."""
        return self.fc(hidden_state)
